#!/bin/bash

# Required settings: stop immediately (with the message after ":?") if either
# BZ or GPU_NUM is unset or empty. Both are interpolated into the training
# command further down.
: "${BZ:?BZ not set}" "${GPU_NUM:?GPU_NUM not set}"

# Command executed by every launched task: announce the host, flush dirty
# pages with sync, then write 3 to vm.drop_caches via sudo.
clear_cache_script="echo -n 'Clearing cache on ' && hostname && sync && sudo /sbin/sysctl vm.drop_caches=3"

# Start as many tasks as the job has nodes (presumably one lands on each node;
# confirm against the cluster's task distribution policy).
srun --ntasks="${SLURM_JOB_NUM_NODES}" bash -c "${clear_cache_script}"

# Launch main.py from the DLRM test directory inside the container given by
# CONT (with the mounts given by MOUNTS) as a single srun task.
#   GPU_NUM == 1 : run python3 directly with the "mirrored" strategy.
#   otherwise    : start GPU_NUM ranks through mpirun with the "multiworker"
#                  strategy.
# Reads: BZ, GPU_NUM, CONT, MOUNTS.
dlrm_dir=/workdir/test/sok_perf_test/dlrm

# Options shared by both launch modes. The file patterns keep their inner
# double quotes so the shell started by "bash -cx" passes the globs to main.py
# unexpanded.
main_py_args="--global_batch_size=${BZ} \
    --train_file_pattern=\"/dataset/train/*.csv\" \
    --test_file_pattern=\"/dataset/test/*.csv\" \
    --embedding_layer=\"SOK\" \
    --embedding_vec_size=32 \
    --bottom_stack 512 256 32 \
    --top_stack 1024 1024 512 256 1"

# Quoted [[ ]] test: GPU_NUM is compared as a string and is never subject to
# word splitting or globbing.
if [[ "${GPU_NUM}" == 1 ]]; then
    launch_cmd="python3 main.py ${main_py_args} --distribute_strategy=\"mirrored\" --gpu_num=1"
else
    launch_cmd="mpirun -np ${GPU_NUM} --allow-run-as-root --oversubscribe python3 main.py ${main_py_args} --distribute_strategy=\"multiworker\""
fi

# "&&" instead of ";" so training is not started from the wrong directory
# when the cd fails. -x makes the inner shell trace the commands it runs.
srun --ntasks=1 --container-image="${CONT}" --container-mounts="${MOUNTS}" \
    bash -cx "cd ${dlrm_dir} && ${launch_cmd}"